Я играю с архитектурами Re snet, чтобы сравнить производительность с CNN. Я использовал этот Re snet для моего первого теста. Я повторно использую свой код для загрузки и подготовки данных для сети, и он работает просто отлично. Кажется, что остальная часть сценария работает нормально, за исключением случаев, когда он попадает в fit_generator. В fit_generator он делает паузу на некоторое время, затем кажется, что выходит, где у меня есть печатная инструкция, говорящая "что случилось?" Я запутался, поскольку ожидал, что сообщение об ошибке или программа взломают sh или что-то еще Я использую windows 10 под управлением последней версии Anaconda. В моей среде я использую python 3.6, последнюю версию Keras 2.3, последнюю версию TensorFlow. Буду признателен за любые идеи.
def batch_generator(X_train, Y_train):
while True:
for fl, lb in zip(X_train, Y_train):
sam, lam = get_IQsamples(fl, lb)
max_iter = sam.shape[0]
sample = [] # store all the generated data batches
label = [] # store all the generated label batches
i = 0
for d, l in zip(sam, lam):
sample.append(d)
label.append(l)
i += 1
if i == max_iter:
break
sample = np.asarray(sample)
label = np.asarray(label)
yield sample, label
def residual_stack(x, f):
# 1x1 conv linear
x = Conv2D(f, (1, 1), strides=1, padding='same', data_format='channels_last')(x)
x = Activation('linear')(x)
# residual unit 1
x_shortcut = x
x = Conv2D(f, (3, 2), strides=1, padding="same", data_format='channels_last')(x)
x = Activation('relu')(x)
x = Conv2D(f, 3, strides=1, padding="same", data_format='channels_last')(x)
x = Activation('linear')(x)
# add skip connection
if x.shape[1:] == x_shortcut.shape[1:]:
x = Add()([x, x_shortcut])
else:
raise Exception('Skip Connection Failure!')
# residual unit 2
x_shortcut = x
x = Conv2D(f, 3, strides=1, padding="same", data_format='channels_last')(x)
x = Activation('relu')(x)
x = Conv2D(f, 3, strides = 1, padding = "same", data_format='channels_last')(x)
x = Activation('linear')(x)
# add skip connection
if x.shape[1:] == x_shortcut.shape[1:]:
x = Add()([x, x_shortcut])
else:
raise Exception('Skip Connection Failure!')
# max pooling layer
x = MaxPooling2D(pool_size=2, strides=None, padding='valid', data_format='channels_last')(x)
return x
.
Define Re sNet Модель
# define resnet model
def ResNet(input_shape, classes):
# create input tensor
x_input = Input(input_shape)
x = x_input
# residual stack
num_filters = 40
x = residual_stack(x, num_filters)
x = residual_stack(x, num_filters)
x = residual_stack(x, num_filters)
x = residual_stack(x, num_filters)
x = residual_stack(x, num_filters)
# output layer
x = Flatten()(x)
x = Dense(128, activation="selu", kernel_initializer="he_normal")(x)
x = Dropout(.5)(x)
x = Dense(128, activation="selu", kernel_initializer="he_normal")(x)
x = Dropout(.5)(x)
x = Dense(classes , activation='softmax', kernel_initializer = glorot_uniform(seed=0))(x)
# Create model
model = Model(inputs = x_input, outputs = x)
model.summary()
return model
model = ResNet((32,32,2),8)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print('Load complete!')
print('\n')
steps = val_length_train // batchsize
valid_steps = val_length // batchsize
history = model.fit_generator(
generator=train_gen,
epochs=3,
verbose=0,
steps_per_epoch=steps,
validation_data=valid_gen,
validation_steps=valid_steps,
callbacks=[tensorboard])
print("what happened?")