Keras останавливается после fit_generator без ошибок - PullRequest
0 голосов
/ 03 февраля 2020

Я играю с архитектурами Re snet, чтобы сравнить производительность с CNN. Я использовал этот Re snet для моего первого теста. Я повторно использую свой код для загрузки и подготовки данных для сети, и он работает просто отлично. Кажется, что остальная часть сценария работает нормально, за исключением случаев, когда он попадает в fit_generator. В fit_generator он делает паузу на некоторое время, затем кажется, что выходит, где у меня есть печатная инструкция, говорящая "что случилось?" Я запутался, поскольку ожидал, что сообщение об ошибке или программа взломают sh или что-то еще Я использую windows 10 под управлением последней версии Anaconda. В моей среде я использую python 3.6, последнюю версию Keras 2.3, последнюю версию TensorFlow. Буду признателен за любые идеи.

def batch_generator(X_train, Y_train):  
    while True:
        for fl, lb in zip(X_train, Y_train):
            sam, lam = get_IQsamples(fl, lb)
            max_iter = sam.shape[0]
            sample = []     # store all the generated data batches
            label = []   # store all the generated label batches

            i = 0
            for d, l in zip(sam, lam):
                sample.append(d)
                label.append(l)
                i += 1
                if i == max_iter:
                    break
            sample = np.asarray(sample)        
            label = np.asarray(label)
            yield sample, label


def residual_stack(x, f):
    
    # 1x1 conv linear
    x = Conv2D(f, (1, 1), strides=1, padding='same', data_format='channels_last')(x)
    x = Activation('linear')(x)


    # residual unit 1    
    x_shortcut = x
    x = Conv2D(f, (3, 2), strides=1, padding="same", data_format='channels_last')(x)
    x = Activation('relu')(x)
    x = Conv2D(f, 3, strides=1, padding="same", data_format='channels_last')(x)
    x = Activation('linear')(x)

    # add skip connection
    if x.shape[1:] == x_shortcut.shape[1:]:
      x = Add()([x, x_shortcut])

    else:
      raise Exception('Skip Connection Failure!')


    # residual unit 2    
    x_shortcut = x
    x = Conv2D(f, 3, strides=1, padding="same", data_format='channels_last')(x)
    x = Activation('relu')(x)
    x = Conv2D(f, 3, strides = 1, padding = "same", data_format='channels_last')(x)
    x = Activation('linear')(x)

    # add skip connection
    if x.shape[1:] == x_shortcut.shape[1:]:
      x = Add()([x, x_shortcut])

    else:
      raise Exception('Skip Connection Failure!')


    # max pooling layer
    x = MaxPooling2D(pool_size=2, strides=None, padding='valid', data_format='channels_last')(x)

    return x

.

Define Re sNet Модель

# define resnet model

def ResNet(input_shape, classes):   

    # create input tensor
    x_input = Input(input_shape)
    x = x_input

    # residual stack
    num_filters = 40
    x = residual_stack(x, num_filters)
    x = residual_stack(x, num_filters)
    x = residual_stack(x, num_filters)
    x = residual_stack(x, num_filters)
    x = residual_stack(x, num_filters)


    # output layer
    x = Flatten()(x)
    x = Dense(128, activation="selu", kernel_initializer="he_normal")(x)
    x = Dropout(.5)(x)
    x = Dense(128, activation="selu", kernel_initializer="he_normal")(x)
    x = Dropout(.5)(x)
    x = Dense(classes , activation='softmax', kernel_initializer = glorot_uniform(seed=0))(x)


    # Create model
    model = Model(inputs = x_input, outputs = x)
    model.summary()

    return model


model = ResNet((32,32,2),8)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


print('Load complete!')
print('\n')


steps = val_length_train // batchsize
valid_steps = val_length // batchsize

history = model.fit_generator(
            generator=train_gen,
            epochs=3,
            verbose=0,
            steps_per_epoch=steps,
            validation_data=valid_gen,
            validation_steps=valid_steps,
            callbacks=[tensorboard])

print("what happened?")

1 Ответ

0 голосов
/ 04 февраля 2020

Вроде. Если есть ошибка, она все равно будет выдана и напечатана с подробным значением 0. Это, как говорится, подробный 0, кажется, вызывает проблемы для некоторых людей. Это сообщение за 2017 год, но я видел ту же проблему, что и недавно, ноябрь 2019 https://github.com/keras-team/keras/issues/5818. Если я использую 0 или 2, все работает нормально, но все это не имеет значения, поскольку сценарий никогда не начинает захватывать данные или тренироваться. Я ценю обратную связь.

...