Я имею дело с проблемой регрессии, когда для данного изображения я хочу предсказать значение 3 параметров (декартовых координат). Для одного изображения у меня может быть несколько приемлемых координат. Для этого я использую нейронную сеть с использованием керас. Чтобы обучить свою сеть, я хочу реализовать пользовательскую функцию потерь, которая будет вычислять евклидово расстояние между прогнозом и ближайшим приемлемым значением. В математическом выражении это можно выразить так:
Форма равна , а форма моей цели равна .
Чтобы рассчитать эту потерю, я сначала изменяю , чтобы иметь правильную форму. Затем я выполняю расчет потерь (используя тензор потока 1.13):
import tensorflow as tf
import tensorflow.keras.backend as K
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
K.set_session(tf.Session(config=config))
from tensorflow.python.keras.applications import ResNet50
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Dropout
import numpy as np
def min_mse(y_pred, y_true):
y_pred_temp = K.repeat(y_pred, K.shape(y_true)[1])
return K.min(K.sum(K.sqrt(y_pred_temp-y_true), axis=-1), axis=-1)
def resnet_model():
model = Sequential()
model.add(ResNet50(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(1024, activation='relu'))
model.add(Dropout(rate=0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(rate=0.2))
model.add(Dense(3, activation='linear'))
model.layers[0].trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001), loss=min_mse)
return model
X = np.random.random((200, 224, 224, 3))
Y = np.random.random((200, 10, 3))
model = resnet_model()
model.fit(X, Y)
Однако этот код выдает ошибку
tenorflow. python .framework.errors_impl.InvalidArgumentError: Ожидается Аргумент умножается на вектор длины 4, но получил длину 3 [[{{node loss / dens_2_loss / Tile}}]]
У меня возникли проблемы с решением, так как я не могу легко напечатать форму из моих разных тензоров, чтобы понять проблему. Есть ли у вас какие-либо подсказки, как решить эту проблему (исправляя мой код или используя другой метод)? Заранее спасибо.