Segmentation model produces NaN output
1 vote / June 17, 2020

I'm trying to segment kidneys and kidney tumors on the kits19 dataset with a U-Net. I load each case (3D data), apply MinMax scaling to it, and then generate batches of 8 slices; the validation data goes through the same pipeline. There is no preprocessing other than the scaling. I then feed the data into a U-Net with batch normalization after every convolutional layer. The callbacks and the number of epochs are set up as in this code:
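
For reference, the per-case scaling is just a min-max rescale to [0, 1]; this is a minimal sketch of what my preprocess_X does (the exact reshaping here is illustrative):

def preprocess_X(volume):
    # min-max scale the whole CT volume into [0, 1]
    volume = volume.astype(np.float32)
    volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)
    # add a channel axis: (slices, 512, 512) -> (slices, 512, 512, 1)
    return volume[..., np.newaxis]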

import time
import numpy as np
import tensorflow as tf

NAME = "unet_without_edge_detection_with_batch_norm_{}".format(int(time.time()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='logs/{}'.format(NAME)),
    tf.keras.callbacks.ModelCheckpoint(filepath='saved_models/checkpoints/{}'.format(NAME),
        save_weights_only=False, monitor='val_accuracy', mode='max', save_best_only=False)
    ]

# list the case numbers for training and validation, dropping the broken cases
case_numbers = np.delete(np.arange(0, 209, 1), [158, 159, 170, 202])
case_numbers_val = case_numbers[179:]
case_numbers_train = case_numbers[:179]

batch_size = 8
data_size = 38650      # total number of training slices
val_data_size = 4520   # total number of validation slices
results = model.fit(
    trainGenerator(case_numbers_train, batch_size),
    steps_per_epoch=int(np.floor(data_size / batch_size)),
    epochs=16,
    callbacks=callbacks,
    verbose=1,
    validation_data=validationGenerator(case_numbers_val, batch_size),
    validation_steps=int(np.floor(val_data_size / batch_size))
    )

model.save('saved_models/model_from_{}'.format(time.time()))
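
One addition that should at least pinpoint when the values blow up (a sketch; I have not rerun the full training with it yet) is TerminateOnNaN, which aborts fit() on the first batch whose loss comes back NaN:

callbacks.append(tf.keras.callbacks.TerminateOnNaN())  # stop training as soon as the loss becomes NaN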

During training, the loss suddenly drops to 1.1921e-07 and stays there, while the accuracy hovers around 0.99 (which I think is expected, since most slices do not contain any kidney at all). When I try to predict something with this model, it outputs an array of NaNs.

Here is the model summary (I use ReLU in all layers and a sigmoid on the output):

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 512, 512, 1) 0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 512, 512, 1)  0           input_1[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 16) 160         lambda[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 16) 64          conv2d[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 512, 512, 16) 0           batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 16) 2320        dropout[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 512, 512, 16) 64          conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 256, 256, 16) 0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 32) 4640        max_pooling2d[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256, 256, 32) 128         conv2d_2[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 256, 256, 32) 0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 256, 256, 32) 9248        dropout_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 256, 256, 32) 128         conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 32) 0           batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 64) 18496       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 128, 128, 64) 256         conv2d_4[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 128, 128, 64) 0           batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 128, 64) 36928       dropout_2[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 128, 128, 64) 256         conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 64, 64)   0           batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 128)  73856       max_pooling2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 128)  512         conv2d_6[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 64, 64, 128)  0           batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 128)  147584      dropout_3[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 64, 64, 128)  512         conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 32, 32, 128)  0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  295168      max_pooling2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 32, 32, 256)  0           batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 32, 32, 256)  590080      dropout_4[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 32, 32, 256)  1024        conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 64, 64, 128)  131200      batch_normalization_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 64, 64, 256)  0           conv2d_transpose[0][0]
                                                                 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 64, 64, 128)  295040      concatenate[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 64, 64, 128)  512         conv2d_10[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 64, 64, 128)  0           batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 64, 64, 128)  147584      dropout_5[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 128)  512         conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 128, 128, 64) 32832       batch_normalization_11[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 128, 128, 128 0           conv2d_transpose_1[0][0]
                                                                 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 128, 128, 64) 73792       concatenate_1[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 128, 128, 64) 256         conv2d_12[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 128, 128, 64) 0           batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 128, 128, 64) 36928       dropout_6[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 128, 128, 64) 256         conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 256, 256, 32) 8224        batch_normalization_13[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 256, 256, 64) 0           conv2d_transpose_2[0][0]
                                                                 batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 256, 256, 32) 18464       concatenate_2[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 256, 256, 32) 128         conv2d_14[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 256, 256, 32) 0           batch_normalization_14[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 256, 256, 32) 9248        dropout_7[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 256, 256, 32) 128         conv2d_15[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 512, 512, 16) 2064        batch_normalization_15[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 512, 512, 32) 0           conv2d_transpose_3[0][0]
                                                                 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 512, 512, 16) 4624        concatenate_3[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 512, 512, 16) 64          conv2d_16[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 512, 512, 16) 0           batch_normalization_16[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 512, 512, 16) 2320        dropout_8[0][0]
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 512, 512, 16) 64          conv2d_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 512, 512, 3)  51          batch_normalization_17[0][0]
==================================================================================================
Total params: 1,946,739
Trainable params: 1,943,795
Non-trainable params: 2,944
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # sum over the spatial axes only, keeping the batch and channel dimensions
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * tf.reduce_sum(y_pred * y_true, axes)
    denominator = tf.reduce_sum(tf.square(y_pred) + tf.square(y_true), axes)
    # the loss is 1 minus the mean dice over batch and channels;
    # epsilon keeps the ratio finite when a channel is empty
    return 1 - tf.reduce_mean(numerator / (denominator + epsilon))

optimizer = tf.keras.optimizers.Adam()

# the model has a single output, so the two loss terms are combined into one function
def combined_loss(y_true, y_pred):
    return tf.keras.losses.categorical_crossentropy(y_true, y_pred) + soft_dice_loss(y_true, y_pred)

model.compile(optimizer=optimizer, loss=combined_loss, metrics=['accuracy'])
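
As a quick sanity check of the dice term on its own (a toy example with made-up tensors), a perfect prediction should give a loss near 0 and an uninformative one should give 0.5:

y_true = tf.constant([[[[1., 0., 0.],
                        [0., 1., 0.],
                        [0., 0., 1.]]]])    # shape (1, 1, 3, 3): three pixels, one per class
print(float(soft_dice_loss(y_true, y_true)))   # ~0.0 for a perfect prediction
uniform = tf.ones_like(y_true) / 3.            # uniform class probabilities
print(float(soft_dice_loss(y_true, uniform)))  # 0.5 for an uninformative prediction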

And here is the generator:

def trainGenerator(case_nums, batch_size):
    # loops over the cases endlessly, yielding fixed-size batches of slices
    while True:
        for case_num in case_nums:
            volume, segmentation = load_case(case_num)
            # preprocess the input volume and the segmentation mask
            X_file = preprocess_X(volume)
            y_file = preprocess_y(segmentation)
            L = X_file.shape[0]
            batch_start = 0
            batch_end = batch_size
            while batch_start < L:
                limit = min(batch_end, L)
                X = X_file[batch_start:limit, :, :, :]
                y = y_file[batch_start:limit, :, :, :]

                yield (X.astype(np.float32), y.astype(np.float32))
                batch_start += batch_size
                batch_end += batch_size

            if case_num == case_nums[-1]:
                # exit the for loop on the last case so the while True restarts from the first one
                break
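
And a quick way to see what the generator actually yields (a sketch of the kind of check I used to confirm the batches contain no NaNs):

gen = trainGenerator(case_numbers_train, batch_size)
X, y = next(gen)
print(X.shape, X.dtype, y.shape, y.dtype)    # expecting (8, 512, 512, 1) float32 and (8, 512, 512, 3) float32
print(np.isnan(X).any(), np.isnan(y).any())  # both should be False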

I tried changing the activation function on the output layer to softmax and to tanh, tried a different optimizer, and tried training without batch normalization, and none of it changed anything. What could be the problem, given that there are no NaNs in the data? I'm using TensorFlow 2.2.0.
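
For what it's worth, this is how the NaNs show up at prediction time (a sketch using one batch from the generator above):

pred = model.predict(X)      # X is one batch from the generator above
print(np.isnan(pred).all())  # True: every output value is NaN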
