попробуйте построить функциональную модель API в классе, но поднять NotImplementedError - PullRequest
0 голосов
/ 04 октября 2019
class MyModel(Model):

    def __init__(self,num_classes=1):
        super(MyModel, self).__init__()

        self.conv1=Convolution2D(filters=8,kernel_size=8,padding='same')
        self.batch_norm1=BatchNormalization()
        self.activation1=Activation('relu')
        self.conv2=Convolution2D(filters=16,kernel_size=8,activation='relu',padding='same')
        self.batch_norm2=BatchNormalization()
        self.activation2=Activation('relu')
        self.MaxPooling2D=MaxPooling2D(pool_size =(2, 2))
        self.Flatten=Flatten()
        self.dense1=Dense(16,activation='relu')
        self.dense2=Dense(num_classes,kernel_regularizer=regularizers.l2(0.4))

    def call(self,inputs):

        x=self.conv1(inputs)
        x=self.batch_norm1(x)
        x=self.activation1(x)
        x=self.conv2(x)
        x=self.batch_norm2(x)
        x=self.activation2(x)
        x=self.MaxPooling2D(x)
        x=self.Flatten(x)
        x=self.dense1(x)
        return self.dense2(x)

    def compute_output_shape(self, input_shape):

        shape = tf.TensorShape(input_shape).as_list()
        shape[-1] = self.num_classes
        return tf.TensorShape(shape)

model=MyModel()

adam=Adam(learning_rate=1e-4)

model.compile(optimizer=adam,loss="mse")

earlystopper = EarlyStopping(monitor='val_loss', patience=20, verbose=0) 

checkpoint =ModelCheckpoint("C:/Users/user/Desktop/research/pic_recognition/cnn2d-model.hdf5",save_best_only=True)

callback_list=[earlystopper,checkpoint]  

model.fit(x_train, y_train, epochs=50, batch_size=8,validation_split=0.1,callbacks=callback_list)

но я получаю эту ошибку:

Файл "", строка 46, в model.fit (x_train, y_train, epochs = 50, batch_size = 8, validation_split = 0.1,callbacks = callback_list) Файл "D: \ Anaconda3 \ lib \ site-packages \ keras \ engine \ training.py", строка 1239, в соответствии с validation_freq = validation_freq) Файл "D: \ Anaconda3 \ lib \ site-packages \ keras \"engine \ training_arrays.py ", строка 216, в файле callbacks.on_epoch_end fit_loop (epoch, epoch_logs) Файл" D: \ Anaconda3 \ lib \ site-packages \ keras \ callbacks \ callbacks.py ", строка 152, в файле on_epoch_end callback.on_epoch_end(эпоха, журналы) Файл "D: \ Anaconda3 \ lib \ site-packages \ keras \ callbacks \ callbacks.py", строка 719, в on_epoch_end self.model.save (filepath, overwrite = True) Файл "D: \ Anaconda3\ lib \ site-packages \ keras \ engine \ network.py ", строка 1150, при повышении сохранения NotImplementedError NotImplementedError

1 Ответ

0 голосов
/ 10 октября 2019

Для пользовательских моделей вы должны использовать «save_weights_only = True» для ModelCheckpoint () или использовать model.save_weights ()

Для получения дополнительной информации см. Ссылки ниже:

  1. https://github.com/tensorflow/tensorflow/issues/22837

  2. https://github.com/keras-team/keras/issues/12922

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...