Написание Керас Модельный класс - PullRequest
0 голосов
/ 28 февраля 2020

Я хочу переписать код ниже как класс:

    input = Input(shape=(28, 28, 1))
    label = Input(shape=(10,))

    x = Conv2D(32, kernel_size=(3, 3), activation='relu')(input)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Flatten()(x)
    x = Dense(512, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    output = ArcFace(num_classes=10)([x, label])

    model = Model([input, label], output)

    model.compile(loss='categorical_crossentropy',
              optimizer=Adam(),
              metrics=['accuracy'])

это то, что у меня есть:

class ArcFace_Model():
    def __init__(self, input_shape, num_classes):
        self.input_shape = (input_shape,)
        self.num_classes = num_classes
        self.label_shape = (num_classes,)

    def build(self):
        #create model
        model = Sequential()

        #add model layers
        model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=self.input_shape))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))
        model.add(Flatten())
        model.add(Dense(512, kernel_initializer='he_normal'))
        model.add(BatchNormalization())
        model.add(ArcFace(num_classes=self.num_classes))

        # loss and optimizer
        optimizer=Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)
        model.compile(loss=categorical_crossentropy,
                        optimizer=optimizer,
                        metrics=['accuracy'])
        return model

но у меня проблема, мой input_shape равен 128, но код ввод (28,28,1). Я делаю это, потому что хочу использовать ArcFace в своей модели. Я нашел слой класса на github .

Есть ли способ исправить это?

...