Как обучить модель с двумя детскими функциями для обнаружения объектов? - PullRequest
0 голосов
/ 29 июня 2018

Я пытаюсь реализовать модель, описанную профессором Эндрю Нгом для обнаружения объекта (объяснение начинается в 10:00).

Он описывает первый элемент выходного вектора как вероятность того, что объект был обнаружен, за которым следуют координаты ограничивающего прямоугольника сопоставляемого объекта (когда сопоставляется один). Последняя часть выходного вектора - это softmax всех классов, которые знает ваша модель.

Как он это объясняет, использование простой квадратичной ошибки для случая, когда есть обнаружение, хорошо, и только разность квадратов y^[0] - y[0]. Я понимаю, что это наивный подход. Я просто хочу реализовать это для обучения.

Мои вопросы

  1. Как реализовать эту условную потерю в тензорном потоке?
  2. Как мне справиться с этим условием около y^[0] при работе с партией.

1 Ответ

0 голосов
/ 30 июня 2018

Как реализовать эту условную потерю в тензорном потоке?

Вы можете преобразовать функцию потерь в:

Error = mask[0]*(y^[0]-y[0])**2 + mask[1]*(y^[1]-y[1])**2 ... mask[n]*(y^[n]-y[n])**2),
where mask = [1, 1,...1] for y[0] = 1 and [1, 0, ...0] for y[0] = 0

Как мне справиться с этим условием около y ^ [0] при работе с партия.

Для партии вы можете создать маску на лету, как:

mask = tf.concat([tf.ones((tf.shape(y)[0],1)),y[:,0][...,None]*y[:,1:]], axis=1)

Код:

y_hat_n = np.array([[3, 3, 3, 3], [3,3,3,3]])
y_1 = np.array([[1, 1, 1, 1], [1,1,1,1]])
y_0 = np.array([[0, 1, 1, 1], [0,1,1,1]])


y = tf.placeholder(tf.float32,[None, 4])
y_hat = tf.placeholder(tf.float32,[None, 4])
mask = tf.concat([tf.ones((tf.shape(y)[0],1)),y[:,0][...,None]*y[:,1:]], axis=1)

error = tf.losses.mean_squared_error(mask*y, mask*y_hat)

with tf.Session() as sess:

   print(sess.run([mask,error], {y:y_0, y_hat:y_hat_n}))
   print(sess.run([mask,error], {y:y_1, y_hat:y_hat_n}))

# Mask and error
#[array([[1., 0., 0., 0.],
#   [1., 0., 0., 0.]], dtype=float32), 2.25]

#[array([[1., 1., 1., 1.],
#   [1., 1., 1., 1.]], dtype=float32), 4.0]
...