Так работает train_on_batch
, он рассчитывает потери, затем обновляет сеть, поэтому мы получаем потери до того, как сеть будет обновлена. Когда мы применяем predict_on_batch
, мы получаем прогноз из обновленной сети.
Под капотом train_on_batch выполняется много других вещей, таких как исправление типов данных, стандартизация данных и т. Д. c.
Ближайший брат train_on_batch
будет test_on_batch
. Если вы запустите test_on_batch
, вы обнаружите, что результат близок к train_on_bacth
, но не тот же.
Вот реализация test_on_batch
: https://github.com/tensorflow/tensorflow/blob/e5bf8de410005de06a7ff5393fafdf832ef1d4ad/tensorflow/python/keras/engine/training_v2_utils.py#L442
он внутренне вызывает _standardize_user_data
, чтобы исправить ваши типы данных, формы данных и т. д. c.
После того, как вы исправите свои x
и y
с правильными формами и типами данных, результат очень близок за исключением некоторой небольшой разницы delta
из-за числовой нестабильности.
Вот минимальный пример, где test_on_batch
, train_on_batch
и predict_on_batch
, кажется, согласуются с результатом численно.
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflow as tf
import numpy as np
# Loss definition
def mse(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true-y_pred))
# Model definition
model = Sequential()
model.add(Dense(1, input_shape = (10,)))
model.compile(optimizer = 'adam', loss = mse, metrics = [mse])
# Data creation
batch_size = 10
x = np.random.random_sample([batch_size,10]).astype('float32').reshape(-1, 10)
y = np.random.random_sample(batch_size).astype('float32').reshape(-1,1)
print(x.shape)
print(y.shape)
model.summary()
# running 5 iterations to check
for _ in range(5):
# Print loss before training
y_pred = model.predict_on_batch(x)
print("Before: " + str(mse(y,y_pred).numpy()))
# Print loss output from train_on_batch
print("Train output: " + str(model.train_on_batch(x,y)))
print(model.test_on_batch(x, y))
# Print loss after training
y_pred = model.predict_on_batch(x)
print("After: " + str(mse(y,y_pred).numpy()))
(10, 10)
(10, 1)
Model: "sequential_25"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_27 (Dense) (None, 1) 11
=================================================================
Total params: 11
Trainable params: 11
Non-trainable params: 0
_________________________________________________________________
Before: 0.30760005
Train output: [0.3076000511646271, 0.3076000511646271]
[0.3052913546562195, 0.3052913546562195]
After: 0.30529135
Before: 0.30529135
Train output: [0.3052913546562195, 0.3052913546562195]
[0.30304449796676636, 0.30304449796676636]
After: 0.3030445
Before: 0.3030445
Train output: [0.30304449796676636, 0.30304449796676636]
[0.3008604645729065, 0.3008604645729065]
After: 0.30086046
Before: 0.30086046
Train output: [0.3008604645729065, 0.3008604645729065]
[0.2987399995326996, 0.2987399995326996]
After: 0.29874
Before: 0.29874
Train output: [0.2987399995326996, 0.2987399995326996]
[0.2966836094856262, 0.2966836094856262]
After: 0.2966836
Примечание: train_on_batch
обновляет вес нейронной сети после расчета потерь, поэтому очевидно, что потери от train_on_batch
и test_on_batch
или predict_on_batch
не будут одинаковыми точно. Правильный вопрос может заключаться в том, почему test_on_batch
и predict_on_batch
дают разные потери с вашими данными.