Я пытаюсь сделать resnet18 для данных CIFAR100, как показано ниже:
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(64, (3,3), strides = (1,1), padding = 'same', activation='linear',input_shape=x_train.shape[1:])
self.bn1 = BatchNormalization()
self.RL = ReLU()
self.avg = AveragePooling2D((4,4))
self.flatten = Flatten()
self.FC = Dense(100, activation='softmax')
def make_BB(self, x,num_filter, size_decrease):
if size_decrease == True:
C1 = Conv2D(num_filter, (3,3),padding = 'same',strides = (2,2),activation = 'linear')
B1 = BatchNormalization()
R1 = ReLU()
C2 = Conv2D(num_filter, (3,3),padding = 'same',strides = (1,1),activation = 'linear')
B2 = BatchNormalization()
C3 = Conv2D(num_filter, (1,1),padding= 'same', strides = (2,2),activation = 'linear')
B3 = BatchNormalization()
forward = B2(C2(R1(B1(C1(x))))) + B3(C3(x))
return forward
else:
C1 = Conv2D(num_filter, (3,3),padding = 'same',strides = (1,1),activation = 'linear')
B1 = BatchNormalization()
R1 = ReLU()
C2 = Conv2D(num_filter, (3,3),padding = 'same',strides = (1,1),activation = 'linear')
B2 = BatchNormalization()
forward = B2(C2(R1(B1(C1(x))))) + x
return forward
def call(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.RL(x)
# BasicBlock1
x = self.make_BB(x,64,False)
x = self.make_BB(x,64,False)
# BasicBlock2
x = self.make_BB(x,128,True)
x = self.make_BB(x,128,False)
# BasicBlock3
x = self.make_BB(x,256,True)
x = self.make_BB(x,256,False)
# BasicBlock4
x = self.make_BB(x,512,True)
x = self.make_BB(x,512,False)
x = self.avg(x)
x = self.flatten(x)
return self.FC(x)
def func(self):
x = tf.keras.layers.Input(shape=(32, 32, 3))
return Model(inputs=[x], outputs=self.call(x))
model = MyModel()
optimizer = tf.keras.optimizers.Adam(0.01)
model.compile(optimizer= optimizer,
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
model.fit(x_train,y_train, epochs = 50)
Но эта модель записывает 1% точности, что означает, что она ничего не узнает.
Я пробовал более простые модели , и это сработало.
Я также пытался изменить скорость обучения на 0,1, 0,001, 0,0005 и т.д. c, но все результаты были такими же.
Я думаю, что часть make_BB делает что-то не так , но я не знаю, что я сделал не так, и даже не могу найти свои ошибки.
Что я сделал не так?