Я построил модель CNN, используя принцип «Подкласса моделей» в Keras.Вот класс, который представляет мою модель:
class ConvNet(tf.keras.Model):
def __init__(self, data_format, classes):
super(ConvNet, self).__init__()
if data_format == "channels_first":
axis = 1
elif data_format == "channels_last":
axis = -1
self.conv_layer1 = tf.keras.layers.Conv2D(filters = 32, kernel_size = 3,strides = (1,1),
padding = "same",activation = "relu")
self.pool_layer1 = tf.keras.layers.MaxPooling2D(pool_size = (2,2), strides = (2,2))
self.conv_layer2 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 3,strides = (1,1),
padding = "same",activation = "relu")
self.pool_layer2 = tf.keras.layers.MaxPooling2D(pool_size = (2,2), strides = (2,2))
self.conv_layer3 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 5,strides = (1,1),
padding = "same",activation = "relu")
self.pool_layer3 = tf.keras.layers.MaxPooling2D(pool_size = (2,2), strides = (1,1),
padding = "same")
self.flatten = tf.keras.layers.Flatten()
self.dense_layer1 = tf.keras.layers.Dense(units = 512, activation = "relu")
self.dense_layer2 = tf.keras.layers.Dense(units = classes, activation = "softmax")
def call(self, inputs, training = True):
output_tensor = self.conv_layer1(inputs)
output_tensor = self.pool_layer1(output_tensor)
output_tensor = self.conv_layer2(output_tensor)
output_tensor = self.pool_layer2(output_tensor)
output_tensor = self.conv_layer3(output_tensor)
output_tensor = self.pool_layer3(output_tensor)
output_tensor = self.flatten(output_tensor)
output_tensor = self.dense_layer1(output_tensor)
return self.dense_layer2(output_tensor)
Я хотел бы знать, как обучать его «с нетерпением», и под этим я подразумеваю отказ от использования compile
и fit
методов.
Я не уверен, как именно построить тренировочный цикл.Я понимаю, что должен выполнить функцию tf.GradientTape.gradient()
, чтобы рассчитать градиенты, а затем использовать optimizers.apply_gradients()
для обновления параметров моей модели.
Что я не понимаю, так это как я могу делать прогнозы с моей модельючтобы получить logits
, а затем использовать их для расчета потерь.Если бы кто-то мог помочь мне с идеей о том, как построить тренировочный цикл, я был бы очень признателен.