Многоуровневая сегментация U-Net в TensorFlow - PullRequest
0 голосов
/ 04 января 2019

Я видел варианты этого вопроса повсюду, но все еще изо всех сил пытаюсь его правильно реализовать. У меня есть МРТ-изображения головного мозга с сегментированными масками, основанными на правде, с 4 классами (0-фон, 1-тип ткани 1, 2-тип ткани 2, 3-необъяснимо пропущено, и 4-тип ткани 4 ... BrATs набор данных)

enter image description here

У меня реализована базовая архитектура U-Net, но мне не удается расширить ее до недвоичной классификации. В частности, функция потерь.

Это то, что я реализовал, но я явно упускаю из виду важные детали:

[...]
output = tf.layers.conv2d_transpose(
conv18,
filters=5,
kernel_size=1,
strides=1,
padding='same',
data_format='channels_last',
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=tf.contrib.layers.l2_regularizer(reg),
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
trainable=True,
name='output',
reuse=None
)

Я думал, что 5 фильтров для (0,1,2,3,4) возможных значений маски будут правильными. Затем я использовал следующую функцию потерь:

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
_sentinel=None,
labels=label,
logits=output,
name='cross_ent_loss'
)

return tf.reduce_mean(loss)

Где логиты будут проходить через вывод сверху, а метки будут моими сложенными изображениями маски [n_batch, x_dim, y_dim, 1]. Глядя на документацию, я знаю, что не пропускаю метки правильного тензора.

Я даже правильно говорю об этом? Как реализовать потерю с помощью мультиклассовых меток, содержащихся в изображении 1 маски?

1 Ответ

0 голосов
/ 24 января 2019

Что-то, что я упустил из документации tf.nn.sparse_softmax_cross_entropy_with_logits

метки: тензор формы [d_0, d_1, ..., d_ {r-1}] (где r - ранг меток и результата) и dtype int32 или int64. Каждая запись в метках должна быть индексом в [0, num_classes). Другие значения будут вызывать исключение при выполнении этой операции на ЦП и возвращать NaN для соответствующих строк потерь и градиента на графическом процессоре.

Таким образом, меняя метки по форме [-1]

label_reshape = tf.reshape(label, [-1])

исправил это!

...