Будучи новичком в DL и Tensorflow, я пытаюсь сослаться на https://machinetalk.org/2019/03/29/neural-machine-translation-with-attention-mechanism/?unapproved=67&moderation-hash=ea8e5dcb97c8236f68291788fbd746a7#comment-67 для реализации модели от последовательности к последовательности с вниманием.Тем не менее, я весьма озадачен тем, как можно реализовать раннюю остановку, чтобы избежать переобучения, поскольку в коде, которому я следую, не используется соглашение model.fit()
.
if MODE == 'train':
for e in range(NUM_EPOCHS):
en_initial_states = encoder.init_states(BATCH_SIZE)
encoder.save_weights(
'checkpoints_luong/encoder/encoder_{}.h5'.format(e + 1))
decoder.save_weights(
'checkpoints_luong/decoder/decoder_{}.h5'.format(e + 1))
for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
loss = train_step(source_seq, target_seq_in,
target_seq_out, en_initial_states)
if batch % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(
e + 1, batch, loss.numpy()))
try:
predict()
except Exception:
continue
Ниже приведен синтаксис для функции train_step()
: -
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
loss = 0
with tf.GradientTape() as tape:
en_outputs = encoder(source_seq, en_initial_states)
en_states = en_outputs[1:]
de_state_h, de_state_c = en_states
# We need to create a loop to iterate through the target sequences
for i in range(target_seq_out.shape[1]):
# Input to the decoder must have shape of (batch_size, length)
# so we need to expand one dimension
decoder_in = tf.expand_dims(target_seq_in[:, i], 1)
logit, de_state_h, de_state_c, _ = decoder(
decoder_in, (de_state_h, de_state_c), en_outputs[0])
# The loss is now accumulated through the whole batch
loss += loss_func(target_seq_out[:, i], logit)
variables = encoder.trainable_variables + decoder.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss / target_seq_out.shape[1]
Чтобы реализовать раннюю остановку в этом случае, это тот же процесс, чтобы использовать callback
и т. Д.,или процесс отличается?Пример для этого случая высоко ценится.