Пользовательская функция потерь Tensorflow Keras для доступа к тензорным каналам - PullRequest
0 голосов
/ 08 марта 2020

У меня есть 2-канальный массив numpy формы (64, 64, 2) в качестве входных данных для моего CNN. Я хочу создать настраиваемую функцию потерь, как описано в https://www.tensorflow.org/guide/keras/train_and_evaluate:

def basic_loss_function(y_true, y_pred):
    return tf.math.reduce_mean(tf.abs(y_true - y_pred))

model.compile(optimizer=keras.optimizers.Adam(),
              loss=basic_loss_function)

model.fit(x_train, y_train, batch_size=64, epochs=3)

Но я хочу кое-что более сложное, чем эта базовая c. То, что мне нужно, это сделать обратное ДПФ (ifft2d), и мои y_pred и y_true, как ожидается, будут иметь форму (64, 64, 2), причем 2 канала являются действительной и мнимой частями fft2. Как я могу правильно получить доступ к каналам y_pred и y_true (я полагаю, что это какой-то слой керас / тензор), чтобы перестроить комплексное число в виде RealPart + 1j * ImagPart (в numpy это будет быть y_pred [:,:, 0] и y_pred [:,:, 1])?

-> Итак, кто-то точно знает, что это за объекты y_pred и y_true и как получить доступ к их каналам / элементы? (Это не так легко отладить, так как его нужно запускать в скомпилированном CNN, поэтому лучше знать это заранее)

1 Ответ

1 голос
/ 08 марта 2020

y_true и y_pred являются тензорами формы (batchsize, ...[output shape]...). Ваш вход имеет форму (64,64,2), но я не уверен, как выглядит ваш вывод, если ваш вывод действительно (64,64,2), то y_pred или y_true имеют форму (64,64,64,2), учитывая ваш batchsize=64.

Работа с Tensors очень похожа на синтаксис numpy, поэтому вы можете использовать нотацию среза с тензорами, например, y_true[:,:,:,0] (обратите внимание на добавленное измерение пакета).

Tensorflow имеет функции для вычисления DFT, FFT , .. и др c. См. tf.signal и tf.signal.rfft2d

Если ваша функция потерь включает операции на входе, а не только на выходах y_true и y_pred, тогда вы можете использовать model.add_loss вместо model.compile(loss= basic_loss_function) следующим образом

x = Input(shape=(64,64,2))
y_true = Input(shape=...))
# your CNN layers
y_pred = Dense(128)(net)

model = Model(input=[x, y_true], output=output)
model.add_loss(basic_loss_function(x, y_true, y_pred))

обратите внимание, что метки (также известные как y_true) теперь являются входными данными для модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...