Для этого не нужно переопределять fit() в подклассе. Можно просто написать собственный цикл обучения (custom training loop). Посмотрите, как я это сделал:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras import Model
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Concatenate
import tensorflow_datasets as tfds
from tensorflow.keras.regularizers import l1, l2, l1_l2
from collections import deque
# Load the MNIST training split as feature dictionaries ('image', 'label')
# together with the dataset metadata.
dataset, info = tfds.load('mnist',
                          with_info=True,
                          split='train',
                          as_supervised=False)

# Number of examples used by this small demo.
TAKE = 1_000

# Cast images to float32 and keep (image, label) pairs.
# reshuffle_each_iteration=False freezes the shuffled order: `train` and
# `test` below are both derived from `data` with take()/skip(), and each of
# them re-iterates `data` from scratch.  With the default (reshuffle on every
# iteration) the two subsets would be cut from a different ordering on every
# pass, so test examples would leak into training across epochs.
data = (
    dataset
    .map(lambda x: (tf.cast(x['image'], tf.float32), x['label']))
    .shuffle(TAKE, reshuffle_each_iteration=False)
    .take(TAKE)
)

# 80 % of the subset is used for training, the remainder for evaluation.
len_train = int(8e-1*TAKE)

# The training part is reshuffled on its own every epoch; this never touches
# the held-out part.
train = data.take(len_train).shuffle(len_train).batch(8)

# `data` holds exactly TAKE elements, so the held-out part is TAKE - len_train.
test = data.skip(len_train).take(TAKE - len_train).batch(8)
class CNN(Model):
    """Two-branch classifier whose branches are concatenated before the output layer.

    One branch applies dense layers directly to the input, the other applies
    convolution and pooling layers.  Both are flattened, concatenated and fed
    to a final dense layer with one unit per class and no activation, i.e. the
    model returns logits.
    """

    def __init__(self):
        """Create all layers; shapes and class count come from the dataset ``info``."""
        super().__init__()
        image_shape = info.features['image'].shape
        n_classes = info.features['label'].num_classes

        # Dense branch, first layer (L1 kernel penalty).
        self.layer1 = Dense(
            32,
            activation=tf.nn.relu,
            kernel_regularizer=l1(1e-2),
            input_shape=image_shape,
        )

        # Convolutional branch: conv -> pool -> conv -> pool.
        self.layer2 = Conv2D(
            filters=16,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation='relu',
            input_shape=image_shape,
        )
        self.layer3 = MaxPooling2D(pool_size=(2, 2))
        self.layer4 = Conv2D(
            filters=32,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation=tf.nn.elu,
            kernel_initializer=tf.keras.initializers.glorot_normal,
        )
        self.layer5 = MaxPooling2D(pool_size=(2, 2))

        # Shared by both branches.
        self.layer6 = Flatten()

        # Second dense layer of the dense branch (L2 kernel penalty).
        self.layer7 = Dense(
            units=64,
            activation=tf.nn.relu,
            kernel_regularizer=l2(1e-2),
        )
        # Dense head of the convolutional branch (combined L1 + L2 penalty).
        self.layer8 = Dense(
            units=64,
            activation=tf.nn.relu,
            kernel_regularizer=l1_l2(l1=1e-2, l2=1e-2),
        )

        # Merge of the two branches and the logits layer.
        self.layer9 = Concatenate()
        self.layer10 = Dense(units=n_classes)

    def call(self, inputs, training=None, **kwargs):
        """Run both branches on ``inputs`` and return the class logits.

        ``training`` and ``kwargs`` are accepted but not used by this method.
        """
        # The dense branch is started first, matching the original call order.
        dense_branch = self.layer1(inputs)

        conv_branch = self.layer2(inputs)
        for layer in (self.layer3, self.layer4, self.layer5,
                      self.layer6, self.layer8):
            conv_branch = layer(conv_branch)

        dense_branch = self.layer6(self.layer7(dense_branch))

        merged = self.layer9([conv_branch, dense_branch])
        return self.layer10(merged)
# Model instance to be trained.
cnn = CNN()

# Loss on integer labels; from_logits=True because it is fed the model's raw outputs.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Running means of the per-batch loss.
train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

# Running accuracies.
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

optimizer = tf.keras.optimizers.Nadam()

# One report line per epoch.
template = (
    'Epoch {:3} Train Loss {:7.4f} Test Loss {:7.4f} '
    'Train Acc {:6.2%} Test Acc {:6.2%} '
)

epochs = 5
# Early-stopping window in epochs; integer division yields 0 when epochs < 50.
early_stop = epochs // 50

# History of losses consulted by the early-stopping check.
loss_hist = deque()

# Single-slot history holding the accuracy to beat, initially 0.
acc_hist = deque([0], maxlen=1)
# Custom training loop: gradients are applied only for batches whose accuracy
# beats the value stored in `acc_hist`.
for epoch in range(1, epochs + 1):
    # Start every epoch with empty running metrics.
    for metric in (train_loss, test_loss, train_acc, test_acc):
        metric.reset_states()

    for images, labels in train:
        with tf.GradientTape() as tape:
            logits = cnn(images, training=True)
            # NOTE(review): `loss` is only the cross-entropy term.  Penalties
            # from any kernel regularizers on the model (Keras exposes them as
            # `cnn.losses`) are not added here -- confirm this is intended.
            loss = loss_object(labels, logits)
        train_loss(loss)
        train_acc(labels, logits)

        # Accuracy of this batch alone (a fresh metric has no history).
        current_acc = tf.metrics.SparseCategoricalAccuracy()(labels, logits)

        # Update the weights only when this batch beats the stored accuracy,
        # then make it the new value to beat.
        if tf.greater(current_acc, acc_hist[-1]):
            print('IMPROVEMENT.')
            gradients = tape.gradient(loss, cnn.trainable_variables)
            optimizer.apply_gradients(zip(gradients, cnn.trainable_variables))
            acc_hist.append(current_acc)

    # Evaluation pass on the held-out batches.
    for images, labels in test:
        logits = cnn(images, training=False)
        loss = loss_object(labels, logits)
        test_loss(loss)
        test_acc(labels, logits)

    print(template.format(epoch,
                          train_loss.result(),
                          test_loss.result(),
                          train_acc.result(),
                          test_acc.result()))

    # Record this epoch's validation loss.  Without this the history stays
    # empty and the early-stopping check below can never trigger.
    loss_hist.append(float(test_loss.result()))

    # Stop when the loss from `early_stop` epochs ago is lower than every loss
    # seen since.  A window of 0 disables the check; it would otherwise call
    # min() on an empty history.
    if early_stop > 0 and len(loss_hist) > early_stop:
        oldest_loss = loss_hist.popleft()
        if oldest_loss < min(loss_hist):
            print('Early stopping. No validation loss decrease in %i epochs.' % early_stop)
            break
Вывод:
IMPROVEMENT.
IMPROVEMENT.
IMPROVEMENT.
IMPROVEMENT.
Epoch 1 Train Loss 21.1698 Test Loss 21.3391 Train Acc 37.13% Test Acc 38.50%
IMPROVEMENT.
IMPROVEMENT.
IMPROVEMENT.
Epoch 2 Train Loss 13.8314 Test Loss 12.2496 Train Acc 50.88% Test Acc 52.50%
Epoch 3 Train Loss 13.7594 Test Loss 12.5884 Train Acc 51.75% Test Acc 53.00%
Epoch 4 Train Loss 13.1418 Test Loss 13.2374 Train Acc 52.75% Test Acc 51.50%
Epoch 5 Train Loss 13.6471 Test Loss 13.3157 Train Acc 49.63% Test Acc 51.50%
Вот часть, которая выполняет эту работу. Здесь используется deque: применение градиентов пропускается, если точность на текущем батче не превышает последний элемент deque.
for images, labels in train:
    # Forward pass under the tape so the loss can be differentiated below.
    with tf.GradientTape() as tape:
        logits = cnn(images, training=True)
        loss = loss_object(labels, logits)

    # Update the running epoch metrics.
    train_loss(loss)
    train_acc(labels, logits)

    # Accuracy of this batch alone, computed with a fresh metric object.
    current_acc = tf.metrics.SparseCategoricalAccuracy()(labels, logits)

    # Skip the weight update unless this batch beats the stored accuracy.
    if not tf.greater(current_acc, acc_hist[-1]):
        continue

    print('IMPROVEMENT.')
    gradients = tape.gradient(loss, cnn.trainable_variables)
    optimizer.apply_gradients(zip(gradients, cnn.trainable_variables))
    # The deque keeps a single element, so this replaces the value to beat.
    acc_hist.append(current_acc)